{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification Trees"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The construction of a classification tree is very similar to that of a regression tree. For a fuller description of the code below, please see the regression tree code on the previous page. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Import packages\n",
    "import numpy as np \n",
    "from itertools import combinations\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "## Load data\n",
    "penguins = sns.load_dataset('penguins')\n",
    "penguins.dropna(inplace = True)\n",
    "X = np.array(penguins.drop(columns = 'species'))\n",
    "y = np.array(penguins['species'])\n",
    "\n",
    "## Train-test split\n",
    "np.random.seed(123)\n",
    "test_frac = 0.25\n",
    "test_size = int(len(y)*test_frac)\n",
    "test_idxs = np.random.choice(np.arange(len(y)), test_size, replace = False)\n",
    "X_train = np.delete(X, test_idxs, 0)\n",
    "y_train = np.delete(y, test_idxs, 0)\n",
    "X_test = X[test_idxs]\n",
    "y_test = y[test_idxs]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will build our classification tree on the {doc}`penguins </content/appendix/data>` dataset from `seaborn`. This dataset has a categorical target variable—penguin breed—with both quantitative and categorical predictors. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Helper Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first create our loss functions. The Gini index and cross-entropy calculate the loss for a single node while the `split_loss()` function creates the weighted loss of a split. "
   ]
  },
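  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what the functions below compute (the notation is an assumption chosen to match the variable `pmk` in the code), let $\\hat{p}_{mk}$ be the fraction of observations in node $m$ that belong to class $k$, and let $N_L$ and $N_R$ be the number of observations in the left and right children of a split. With sums running over the classes present in the node,\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "\\text{Gini index:} \\quad & \\sum_{k} \\hat{p}_{mk}(1 - \\hat{p}_{mk}) \\\\\n",
    "\\text{Cross-entropy:} \\quad & -\\sum_{k} \\hat{p}_{mk} \\log_2 \\hat{p}_{mk} \\\\\n",
    "\\text{Split loss:} \\quad & \\frac{N_L \\cdot \\text{loss}(L) + N_R \\cdot \\text{loss}(R)}{N_L + N_R}\n",
    "\\end{align*}\n",
    "$$\n",
    "\n",
    "Both node losses equal zero when a node contains a single class, and both are largest when the classes present are evenly represented."
   ]
  },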
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Loss Functions\n",
    "def gini_index(y):\n",
    "    size = len(y)\n",
    "    classes, counts = np.unique(y, return_counts = True)\n",
    "    pmk = counts/size\n",
    "    return np.sum(pmk*(1-pmk))\n",
    "     \n",
    "def cross_entropy(y):\n",
    "    size = len(y)\n",
    "    classes, counts = np.unique(y, return_counts = True)\n",
    "    pmk = counts/size\n",
    "    return -np.sum(pmk*np.log2(pmk))\n",
    "\n",
    "def split_loss(child1, child2, loss = cross_entropy):\n",
    "    return (len(child1)*loss(child1) + len(child2)*loss(child2))/(len(child1) + len(child2))"
   ]
  },
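  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To sanity-check the loss functions on a small case, the following is a minimal, self-contained sketch in plain Python (no `numpy`). The names `class_fractions()`, `gini_from_labels()`, `entropy_from_labels()`, and `weighted_split_loss()` are hypothetical stand-ins written for this illustration; they are intended to mirror `gini_index()`, `cross_entropy()`, and `split_loss()`, and the labels are toy values rather than rows from the dataset. A split that separates the classes perfectly should have zero loss, while a split that leaves both children as mixed as the parent should not improve on the parent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from math import log2\n",
    "\n",
    "def class_fractions(labels):\n",
    "    # Fraction of observations in each class present in the node\n",
    "    counts = Counter(labels)\n",
    "    return [count / len(labels) for count in counts.values()]\n",
    "\n",
    "def gini_from_labels(labels):\n",
    "    return sum(p * (1 - p) for p in class_fractions(labels))\n",
    "\n",
    "def entropy_from_labels(labels):\n",
    "    return -sum(p * log2(p) for p in class_fractions(labels))\n",
    "\n",
    "def weighted_split_loss(left, right, loss=entropy_from_labels):\n",
    "    # Average of the child losses, weighted by child size\n",
    "    n_left, n_right = len(left), len(right)\n",
    "    return (n_left * loss(left) + n_right * loss(right)) / (n_left + n_right)\n",
    "\n",
    "# Toy node with two evenly represented classes (illustrative labels only)\n",
    "parent = ['Adelie', 'Adelie', 'Gentoo', 'Gentoo']\n",
    "parent_gini = gini_from_labels(parent)\n",
    "parent_entropy = entropy_from_labels(parent)\n",
    "\n",
    "# A split that isolates each class versus one that keeps both children mixed\n",
    "pure_split = weighted_split_loss(parent[:2], parent[2:])\n",
    "mixed_split = weighted_split_loss(['Adelie', 'Gentoo'], ['Adelie', 'Gentoo'])\n",
    "print(parent_gini, parent_entropy, pure_split, mixed_split)"
   ]
  },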
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's define a few miscellaneous helper functions. As in the regression tree construction, `all_rows_equal()` checks if all of a bud's rows (observations) are equal across all predictors. If this is the case, this bud will not be split and instead becomes a terminal leaf. The second function, `possible_splits()`, returns all possible ways to divide the classes in a categorical predictor into two. Specifically, it returns all possible sets of values which can be used to funnel observations into the \"left\" child node. An example is given below for a predictor with four categories, $a$ through $d$. The set $\\{a, b\\}$, for instance, would imply observations where that predictor equals $a$ or $b$ go to the left child and other observations go to the right child. (Note that this function requires the `itertools` package). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('a',),\n",
       " ('b',),\n",
       " ('c',),\n",
       " ('d',),\n",
       " ('a', 'b'),\n",
       " ('a', 'c'),\n",
       " ('a', 'd'),\n",
       " ('b', 'c'),\n",
       " ('b', 'd'),\n",
       " ('c', 'd')]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Helper Functions\n",
    "def all_rows_equal(X):\n",
    "    return (X == X[0]).all()\n",
    "\n",
    "def possible_splits(x):\n",
    "    L_values = []\n",
    "    for i in range(1, int(np.floor(len(x)/2)) + 1):\n",
    "        L_values.extend(list(combinations(x, i)))\n",
    "    return L_values\n",
    "\n",
    "possible_splits(['a','b','c','d'])"
   ]
  },
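  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of a left-value set concrete, here is a small self-contained sketch of how such a set partitions observations. The functions `candidate_left_sets()` and `partition_by_left_values()` are hypothetical names introduced for this illustration (the first is meant to enumerate the same subsets as `possible_splits()`), and the `island` values are a toy sample rather than the training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import combinations\n",
    "\n",
    "def candidate_left_sets(categories):\n",
    "    # Subsets of size 1 up to half the number of categories\n",
    "    left_sets = []\n",
    "    for size in range(1, len(categories) // 2 + 1):\n",
    "        left_sets.extend(combinations(categories, size))\n",
    "    return left_sets\n",
    "\n",
    "def partition_by_left_values(values, left_values):\n",
    "    # Positions of the observations sent to the left and right child\n",
    "    left = [i for i, v in enumerate(values) if v in left_values]\n",
    "    right = [i for i, v in enumerate(values) if v not in left_values]\n",
    "    return left, right\n",
    "\n",
    "# Toy categorical predictor with three categories (illustrative values only)\n",
    "island = ['Biscoe', 'Dream', 'Torgersen', 'Dream', 'Biscoe']\n",
    "left_sets = candidate_left_sets(sorted(set(island)))\n",
    "\n",
    "# With three categories, only single-category left sets are generated\n",
    "for left_values in left_sets:\n",
    "    left, right = partition_by_left_values(island, left_values)\n",
    "    print(left_values, left, right)"
   ]
  },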
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Helper Classes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we define two classes to help our main decision tree classifier. These classes are essentially identical to those discussed in the regression tree page. The only difference is the loss function used to evaluate a split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Node:\n",
    "    \n",
    "    def __init__(self, Xsub, ysub, ID, obs, depth = 0, parent_ID = None, leaf = True):\n",
    "        self.Xsub = Xsub\n",
    "        self.ysub = ysub\n",
    "        self.ID = ID\n",
    "        self.obs = obs\n",
    "        self.size = len(ysub)\n",
    "        self.depth = depth\n",
    "        self.parent_ID = parent_ID\n",
    "        self.leaf = leaf\n",
    "        \n",
    "\n",
    "class Splitter:\n",
    "    \n",
    "    def __init__(self):\n",
    "        self.loss = np.inf\n",
    "        self.no_split = True\n",
    "        \n",
    "    def _replace_split(self, Xsub_d, loss, d, dtype = 'quant', t = None, L_values = None):\n",
    "        self.loss = loss\n",
    "        self.d = d\n",
    "        self.dtype = dtype\n",
    "        self.t = t\n",
    "        self.L_values = L_values\n",
    "        self.no_split = False\n",
    "        if dtype == 'quant':\n",
    "            self.L_obs = self.obs[Xsub_d <= t]\n",
    "            self.R_obs = self.obs[Xsub_d > t]\n",
    "        else:\n",
    "            self.L_obs = self.obs[np.isin(Xsub_d, L_values)]\n",
    "            self.R_obs = self.obs[~np.isin(Xsub_d, L_values)]\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Main Class"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we create the main class for our classification tree. This again is essentially identical to the regression tree class. In addition to differing in the loss function used to evaluate splits, this tree differs from the regression tree in how it forms predictions. In regression trees, the fitted value for a test observation was the average target variable of the training observations landing in the same leaf. In the classification tree, since our target variable is categorical, we instead use the most common class among training observations landing in the same leaf. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecisionTreeClassifier:\n",
    "    \n",
    "    #############################\n",
    "    ######## 1. TRAINING ########\n",
    "    #############################\n",
    "    \n",
    "    ######### FIT ##########\n",
    "    def fit(self, X, y, loss_func = cross_entropy, max_depth = 100, min_size = 2, C = None):\n",
    "        \n",
    "        ## Add data\n",
    "        self.X = X\n",
    "        self.y = y\n",
    "        self.N, self.D = self.X.shape\n",
    "        dtypes = [np.array(list(self.X[:,d])).dtype for d in range(self.D)]\n",
    "        self.dtypes = ['quant' if (dtype == float or dtype == int) else 'cat' for dtype in dtypes]\n",
    "\n",
    "        ## Add model parameters\n",
    "        self.loss_func = loss_func\n",
    "        self.max_depth = max_depth\n",
    "        self.min_size = min_size\n",
    "        self.C = C\n",
    "        \n",
    "        ## Initialize nodes\n",
    "        self.nodes_dict = {}\n",
    "        self.current_ID = 0\n",
    "        initial_node = Node(Xsub = X, ysub = y, ID = self.current_ID, obs = np.arange(self.N), parent_ID = None)\n",
    "        self.nodes_dict[self.current_ID] = initial_node\n",
    "        self.current_ID += 1\n",
    "        \n",
    "        # Build\n",
    "        self._build()\n",
    "\n",
    "    ###### BUILD TREE ######\n",
    "    def _build(self):\n",
    "        \n",
    "        eligible_buds = self.nodes_dict \n",
    "        for layer in range(self.max_depth):\n",
    "            \n",
    "            ## Find eligible nodes for layer iteration\n",
    "            eligible_buds = {ID:node for (ID, node) in self.nodes_dict.items() if \n",
    "                                (node.leaf == True) &\n",
    "                                (node.size >= self.min_size) & \n",
    "                                (~all_rows_equal(node.Xsub)) &\n",
    "                                (len(np.unique(node.ysub)) > 1)}\n",
    "            if len(eligible_buds) == 0:\n",
    "                break\n",
    "            \n",
    "            ## split each eligible parent\n",
    "            for ID, bud in eligible_buds.items():\n",
    "                                \n",
    "                ## Find split\n",
    "                self._find_split(bud)\n",
    "                \n",
    "                ## Make split\n",
    "                if not self.splitter.no_split:\n",
    "                    self._make_split()\n",
    "                \n",
    "    ###### FIND SPLIT ######\n",
    "    def _find_split(self, bud):\n",
    "        \n",
    "        ## Instantiate splitter\n",
    "        splitter = Splitter()\n",
    "        splitter.bud_ID = bud.ID\n",
    "        splitter.obs = bud.obs\n",
    "        \n",
    "        ## For each (eligible) predictor...\n",
    "        if self.C is None:\n",
    "            eligible_predictors = np.arange(self.D)\n",
    "        else:\n",
    "            eligible_predictors = np.random.choice(np.arange(self.D), self.C, replace = False)\n",
    "        for d in sorted(eligible_predictors):\n",
    "            Xsub_d = bud.Xsub[:,d]\n",
    "            dtype = self.dtypes[d]\n",
    "            if len(np.unique(Xsub_d)) == 1:\n",
    "                continue\n",
    "\n",
    "            ## For each value...\n",
    "            if dtype == 'quant':\n",
    "                for t in np.unique(Xsub_d)[:-1]:\n",
    "                    ysub_L = bud.ysub[Xsub_d <= t]\n",
    "                    ysub_R = bud.ysub[Xsub_d > t]\n",
    "                    loss = split_loss(ysub_L, ysub_R, loss = self.loss_func)\n",
    "                    if loss < splitter.loss:\n",
    "                        splitter._replace_split(Xsub_d, loss, d, 'quant', t = t)\n",
    "            else:\n",
    "                for L_values in possible_splits(np.unique(Xsub_d)):\n",
    "                    ysub_L = bud.ysub[np.isin(Xsub_d, L_values)]\n",
    "                    ysub_R = bud.ysub[~np.isin(Xsub_d, L_values)]\n",
    "                    loss = split_loss(ysub_L, ysub_R, loss = self.loss_func)\n",
    "                    if loss < splitter.loss: \n",
    "                        splitter._replace_split(Xsub_d, loss, d, 'cat', L_values = L_values)\n",
    "                        \n",
    "        ## Save splitter\n",
    "        self.splitter = splitter\n",
    "    \n",
    "    ###### MAKE SPLIT ######\n",
    "    def _make_split(self):\n",
    "        \n",
    "        ## Update parent node\n",
    "        parent_node = self.nodes_dict[self.splitter.bud_ID]\n",
    "        parent_node.leaf = False\n",
    "        parent_node.child_L = self.current_ID\n",
    "        parent_node.child_R = self.current_ID + 1\n",
    "        parent_node.d = self.splitter.d\n",
    "        parent_node.dtype = self.splitter.dtype\n",
    "        parent_node.t = self.splitter.t        \n",
    "        parent_node.L_values = self.splitter.L_values\n",
    "        parent_node.L_obs, parent_node.R_obs = self.splitter.L_obs, self.splitter.R_obs\n",
    "        \n",
    "        ## Get X and y data for children\n",
    "        if parent_node.dtype == 'quant':\n",
    "            L_condition = parent_node.Xsub[:,parent_node.d] <= parent_node.t\n",
    "        else:\n",
    "            L_condition = np.isin(parent_node.Xsub[:,parent_node.d], parent_node.L_values)\n",
    "        Xchild_L = parent_node.Xsub[L_condition]\n",
    "        ychild_L = parent_node.ysub[L_condition]\n",
    "        Xchild_R = parent_node.Xsub[~L_condition]\n",
    "        ychild_R = parent_node.ysub[~L_condition]\n",
    "        \n",
    "        ## Create child nodes\n",
    "        child_node_L = Node(Xchild_L, ychild_L, obs = parent_node.L_obs, depth = parent_node.depth + 1,\n",
    "                            ID = self.current_ID, parent_ID = parent_node.ID)\n",
    "        child_node_R = Node(Xchild_R, ychild_R, obs = parent_node.R_obs, depth = parent_node.depth + 1,\n",
    "                            ID = self.current_ID+1, parent_ID = parent_node.ID)\n",
    "        self.nodes_dict[self.current_ID] = child_node_L\n",
    "        self.nodes_dict[self.current_ID + 1] = child_node_R\n",
    "        self.current_ID += 2\n",
    "                \n",
    "            \n",
    "    #############################\n",
    "    ####### 2. PREDICTING #######\n",
    "    #############################\n",
    "    \n",
    "    ###### LEAF MODES ######\n",
    "    def _get_leaf_modes(self):\n",
    "        self.leaf_modes = {}\n",
    "        for node_ID, node in self.nodes_dict.items():\n",
    "            if node.leaf:\n",
    "                values, counts = np.unique(node.ysub, return_counts=True)\n",
    "                self.leaf_modes[node_ID] = values[np.argmax(counts)]\n",
    "    \n",
    "    ####### PREDICT ########\n",
    "    def predict(self, X_test):\n",
    "        \n",
    "        # Calculate leaf modes\n",
    "        self._get_leaf_modes()\n",
    "        \n",
    "        yhat = []\n",
    "        for x in X_test:\n",
    "            node = self.nodes_dict[0] \n",
    "            while not node.leaf:\n",
    "                if node.dtype == 'quant':\n",
    "                    if x[node.d] <= node.t:\n",
    "                        node = self.nodes_dict[node.child_L]\n",
    "                    else:\n",
    "                        node = self.nodes_dict[node.child_R]\n",
    "                else:\n",
    "                    if x[node.d] in node.L_values:\n",
    "                        node = self.nodes_dict[node.child_L]\n",
    "                    else:\n",
    "                        node = self.nodes_dict[node.child_R]\n",
    "            yhat.append(self.leaf_modes[node.ID])\n",
    "        return np.array(yhat)\n",
    "            \n"
   ]
  },
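  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The majority-vote rule used by `_get_leaf_modes()` can also be illustrated in isolation. The sketch below is a plain-Python stand-in, with `leaf_mode()` a hypothetical name introduced here and toy labels rather than real leaf contents. It assumes ties should be broken in favor of the label that sorts first, which is what `np.unique` followed by `np.argmax` does in the class above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def leaf_mode(labels):\n",
    "    # Most common class among the training labels in a leaf.\n",
    "    # Ties go to the label that sorts first.\n",
    "    counts = Counter(labels)\n",
    "    best = max(counts.values())\n",
    "    return min(label for label, count in counts.items() if count == best)\n",
    "\n",
    "# Toy leaf contents (illustrative labels only)\n",
    "majority_leaf = ['Gentoo', 'Adelie', 'Gentoo', 'Chinstrap']\n",
    "tied_leaf = ['Gentoo', 'Adelie']\n",
    "print(leaf_mode(majority_leaf), leaf_mode(tied_leaf))"
   ]
  },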
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A classificaiton tree is built on the {doc}`penguins </content/appendix/data>` dataset. We evaluate the predictions on a test set and find that roughly 95% of observations are correctly classified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9518072289156626"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Build classifier \n",
    "tree = DecisionTreeClassifier()\n",
    "tree.fit(X_train, y_train, max_depth = 10, min_size = 10)\n",
    "y_test_hat = tree.predict(X_test)\n",
    "\n",
    "## Evaluate on test data\n",
    "np.mean(y_test_hat == y_test)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Edit Metadata",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
